"""
## Fit voxel-wise encoding models for each subject, using features of the natural images they viewed.

## If you have already run this same step for LaVCa, you can skip it.

python -m BrainSCUBA.step1_encoding \
    --subject_name subj01 \
    --atlas cortex \
    --betas_norm \
    --modalities image \
    --modality_hparams default \
    --feat_names CLIP-ViT-B-32 \
    --select_layers all \
    --device 0 \
    --feature_space mono \
    --reg_type ridge \
    --perm_type gauss \
    --reduce_dim default None
"""

import os
import argparse
import matplotlib.pyplot as plt
import numpy as np
import torch
from himalaya.backend import set_backend
from himalaya.scoring import correlation_score
from utils.utils import (
    TrnVal, gen_nulldistrib_gauss, gen_nulldistrib_block,
    fdr_correction, make_himalaya_pipeline, make_filename,
    collect_fmri_byroi_for_nsd, collect_stim_for_nsd
    )

# Import-time sanity check: report whether PyTorch can see a CUDA device
# (True is expected on GPU hosts) and the CUDA version PyTorch reports.
_cuda_available = torch.cuda.is_available()
print(_cuda_available)
print(torch.version.cuda)
# Reset PyTorch's preferred CUDA linear-algebra library to its default choice.
torch.backends.cuda.preferred_linalg_library("default")

def load_resp_wholevoxels_for_nsd(subject_name, dataset="all", atlas="streams", normalize=False):
    """Load NSD fMRI responses for the TRAIN and VALID splits of one subject.

    Args:
        subject_name: subject identifier forwarded to ``collect_fmri_byroi_for_nsd``.
        dataset: accepted for interface compatibility; not used here.
        atlas: atlas name forwarded as ``atlasname``.
        normalize: forwarded as ``norm``.

    Returns:
        A ``TrnVal`` whose ``trn`` / ``val`` fields hold the TRAIN / VALID responses.
    """
    shared_kwargs = {"atlasname": atlas, "norm": normalize}
    # Keyword arguments are evaluated left to right, so TRAIN is loaded before VALID.
    return TrnVal(
        trn=collect_fmri_byroi_for_nsd(subject_name, trainvalid="TRAIN", **shared_kwargs),
        val=collect_fmri_byroi_for_nsd(subject_name, trainvalid="VALID", **shared_kwargs),
    )

def load_stim_for_nsd(subject, modality, feat_path, reduce_dim, dataset="all"):
    """Load trial-averaged stimulus features and return ``(train, validation)``.

    Args:
        subject: subject identifier forwarded to ``collect_stim_for_nsd``.
        modality: stimulus modality name (e.g. ``"image"``).
        feat_path: directory of the feature layer to load.
        reduce_dim: dimensionality-reduction spec forwarded unchanged.
        dataset: accepted for interface compatibility; not used here.

    Returns:
        Tuple ``(stim.trn, stim.val)``.
    """
    # "ave" is passed through unchanged; presumably selects trial-averaged
    # features — TODO confirm against collect_stim_for_nsd.
    stim = collect_stim_for_nsd(subject, modality, feat_path, "ave", reduce_dim)
    # The result is assumed to be TrnVal-like (fields ``trn`` / ``val``, as built in
    # load_resp_wholevoxels_for_nsd); the previous ``stim.va`` was a typo that
    # raised AttributeError.
    return stim.trn, stim.val


def _l2_normalize(x: np.ndarray) -> np.ndarray:
    """Return ``x`` with every row scaled to unit L2 norm.

    Rows whose norm is exactly zero are divided by 1 (left unchanged) instead
    of producing NaNs.
    """
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    norms[norms == 0] = 1
    return x / norms


def _select_backend(device: int):
    """Activate a himalaya backend for ``device`` and return ``(backend, device)``.

    A valid CUDA index selects ``cuda:<device>`` with the ``torch_cuda`` backend
    and is returned unchanged. Any other value (negative, or beyond the number
    of visible CUDA devices) falls back to the CPU ``torch`` backend and returns
    ``"cpu"``, so downstream code never receives an unusable CUDA index.
    """
    if 0 <= device < torch.cuda.device_count():
        torch.cuda.set_device(f"cuda:{device}")
        backend = set_backend("torch_cuda", on_error="warn")
        print("Running on GPU...")
        return backend, device

    if device >= 0:
        print("The CUDA device you specified is not available.")
    print("Running on CPU...")
    return set_backend("torch", on_error="warn"), "cpu"


def _extract_fit_results(pipeline, backend) -> tuple[np.ndarray, dict[str, np.ndarray]]:
    """Return ``(cv_scores, {"coef_": coef})`` from a fitted ridge pipeline.

    Raises:
        ValueError: if the pipeline has neither a ``ridgecv`` nor a
            ``kernelridgecv`` step.
    """
    steps = pipeline.named_steps
    if "ridgecv" in steps:
        model = steps["ridgecv"]
        coef = model.coef_
    elif "kernelridgecv" in steps:
        model = steps["kernelridgecv"]
        # Kernel ridge stores dual coefficients; convert them to feature-space weights.
        coef = model.get_primal_coef()
    else:
        raise ValueError(
            "Expected a 'ridgecv' or 'kernelridgecv' step in the pipeline, "
            f"got {list(steps)}"
        )
    return backend.to_numpy(model.cv_scores_), {"coef_": backend.to_numpy(coef)}


def _plot_score_histogram(score: np.ndarray, label: str):
    """Return a new figure with a histogram of the finite values of ``score``.

    Normally uses 100 bin edges spanning ``[0, max(score)]``. When no finite
    positive score exists, that edge array would be decreasing or degenerate,
    so matplotlib's default binning is used instead.
    """
    finite = score[np.isfinite(score)]
    upper = finite.max() if finite.size else 0.0
    bins = np.linspace(0, upper, 100) if upper > 0 else 100
    fig = plt.figure()
    plt.hist(finite, bins, alpha=1.0, label=label)
    plt.title(r"Histogram of correlation coefficient score")
    plt.legend()
    return fig


def mono_regressor(
    stim_trn: np.ndarray,
    stim_val: np.ndarray,
    resp: TrnVal[np.ndarray],
    emb_name: str,
    device: int,
    perm_type: str,
    save_pred: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, plt.Figure, dict[str, np.ndarray]]:
    """Train an encoder for a single (mono) feature space and score it.

    Stimulus features are cast to float32 and L2-normalised per sample. A
    pipeline from ``make_himalaya_pipeline`` is then fit with 5-fold CV over
    ``alphas``. Finally, correlations between validation predictions and
    ``resp.val`` are tested for significance with ``fdr_correction``.

    Args:
        stim_trn: training stimulus features, one row per sample.
        stim_val: validation stimulus features, one row per sample.
        resp: training / validation responses (``resp.trn`` / ``resp.val``).
        emb_name: label used in the histogram legend.
        device: CUDA device index; a negative or unavailable index runs on CPU.
        perm_type: null-distribution type, ``"gauss"`` or ``"block"``.
        save_pred: currently unused; kept for interface compatibility.

    Returns:
        ``(cv_scores, score, pvalue_corrected, fig, params)``:
        - ``score`` holds the validation correlation per response column.
        - ``fig`` is a matplotlib figure with its histogram; the caller is
          responsible for closing it.
        - ``params["coef_"]`` holds the fitted coefficients.

    Raises:
        ValueError: if ``perm_type`` is unknown, or if the pipeline lacks a
            ridge step.
    """
    # Validate before the expensive fit rather than failing afterwards.
    if perm_type not in ("gauss", "block"):
        raise ValueError(f"perm_type must be 'gauss' or 'block', got {perm_type!r}")

    alphas = np.logspace(-4, 20, 25)

    x_trn = _l2_normalize(stim_trn.astype("float32"))
    x_val = _l2_normalize(stim_val.astype("float32"))
    y_trn, y_val = resp.trn.astype("float32"), resp.val.astype("float32")

    # Sanity checks: features should have unit norm; inspect response statistics.
    print("Norm of first x_trn sample after normalization:", np.linalg.norm(x_trn[0]))
    print("Norm of first x_val sample after normalization:", np.linalg.norm(x_val[0]))
    print("Mean of y_trn:", np.mean(y_trn))
    print("Std of y_trn:", np.std(y_trn))
    print("Mean of y_val:", np.mean(y_val))
    print("Std of y_val:", np.std(y_val))
    n_samples_val = x_val.shape[0]

    # ``device`` becomes "cpu" on any CPU fallback, so the block null
    # distribution below never receives an invalid CUDA index.
    backend, device = _select_backend(device)

    pipeline = make_himalaya_pipeline(n_samples=x_trn.shape[0],
                                      n_features=x_trn.shape[1],
                                      cv=5,
                                      alpha=alphas,
                                      score_func=correlation_score)
    pipeline.fit(x_trn, y_trn)
    cv_scores, params = _extract_fit_results(pipeline, backend)

    y_val_pred = pipeline.predict(x_val)
    score = backend.to_numpy(correlation_score(y_val_pred, y_val))
    print(f"Mean CV score: {cv_scores.mean()}")
    print(f"Mean score: {score.mean()}")

    if perm_type == "gauss":
        rccs = gen_nulldistrib_gauss(len(score), n_samples_val)
    else:
        rccs = gen_nulldistrib_block(y_val, y_val_pred, device=device)

    _, pvalue_corrected = fdr_correction(score, rccs)

    fig = _plot_score_histogram(score, emb_name)
    return cv_scores, score, pvalue_corrected, fig, params


def main(args) -> None:
    """Fit and save a voxel-wise encoding model for every requested feature layer.

    For each layer, this:
    - loads the stimulus features;
    - trains an encoder with ``mono_regressor``;
    - saves CV correlations, validation correlations, FDR-corrected p-values,
      coefficients and a score histogram under
      ``./data/nsd/encoding/<subject>/scores/<modality>/<hparams>/<feat>/<layer>``.

    A ``temp_encoding.txt`` marker in each layer's output directory lets
    several workers share the job. A layer whose marker already exists is
    skipped. Only the worker that created the marker removes it, when it
    finishes (or fails) that layer.
    """
    # Older CLIs did not define --dataset; fall back to the loaders' default.
    dataset = getattr(args, "dataset", "all")

    print(f"Now processing: {args.subject_name}")
    print("Loading response data...")
    resp = load_resp_wholevoxels_for_nsd(args.subject_name, dataset, args.atlas, args.betas_norm)

    feat_name = args.feat_names[0]
    modality_name = args.modalities[0]
    modality_hparams_savename = '_'.join(map(str, args.modality_hparams))
    feat_root = f"./data/stim_features/nsd/{modality_name}/{modality_hparams_savename}/{feat_name}"
    scores_root = f"./data/nsd/encoding/{args.subject_name}/scores"

    if "all" in args.select_layers:
        # Sorted so every worker walks the layers in the same, reproducible order.
        layers_to_process = sorted(d for d in os.listdir(feat_root) if "layer" in d)
    else:
        layers_to_process = [f"layer{layer}" for layer in args.select_layers]

    for layer_name in layers_to_process:
        print(f"Loading {feat_name}'s {layer_name}...")
        layer_path = f"{feat_root}/{layer_name}"
        scores_save_dir = f"{scores_root}/{modality_name}/{modality_hparams_savename}/{feat_name}/{layer_name}"
        os.makedirs(scores_save_dir, exist_ok=True)

        # Claim the layer with a marker file. Mode "x" creates it atomically and
        # fails if another worker already holds it, so there is no race between
        # checking and creating. Skipping happens OUTSIDE the cleanup try/finally,
        # so a skipped layer never deletes another worker's marker.
        temp_file = f"{scores_save_dir}/temp_encoding.txt"
        try:
            with open(temp_file, "x"):
                pass
        except FileExistsError:
            print(f"temp file exists in {layer_name}. Skip this layer.")
            continue

        try:
            print(f"Now processing: {layer_name}")
            stim_trn, stim_val = load_stim_for_nsd(
                args.subject_name, modality_name, layer_path, args.reduce_dim, dataset
            )

            print("Training...")
            cv_scores, scores, pvalue_corrected, fig, params = mono_regressor(
                stim_trn, stim_val, resp, feat_name, args.device, args.perm_type, args.save_pred
            )

            filename = make_filename(args.reduce_dim, dataset)
            if args.betas_norm:
                filename += "_betanorm"

            np.save(f"{scores_save_dir}/cv_cc_{filename}.npy", cv_scores)
            np.save(f"{scores_save_dir}/cc_{filename}.npy", scores)
            np.save(f"{scores_save_dir}/pvalues_corrected_{filename}.npy", pvalue_corrected)
            np.save(f"{scores_save_dir}/coef_{filename}.npy", params['coef_'])
            fig.savefig(f"{scores_save_dir}/dist_{filename}.png")
            # Release the figure; otherwise one figure per layer accumulates in pyplot.
            plt.close(fig)
        finally:
            # Best-effort removal of the marker this worker created.
            try:
                os.remove(temp_file)
            except OSError:
                pass


def build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this voxel-wise encoding script."""
    parser = argparse.ArgumentParser(
        description="Train an encoding model to predict fMRI responses from stimulus features."
    )

    parser.add_argument(
        "--subject_name",
        type=str,
        required=True,
        help="Name of the subject to train the model on.",
    )

    parser.add_argument(
        "--atlas",
        type=str,
        required=True,
        help="Name of the atlas to use.",
    )
    parser.add_argument(
        "--betas_norm",
        action='store_true',
        help="Whether to normalize betas.",
    )
    # "+" so an empty value list is rejected here instead of crashing later on [0].
    parser.add_argument(
        "--modalities",
        nargs="+",
        type=str,
        required=True,
        help="Name of the modality to use."
    )

    # Defaults are lists because nargs produces lists. They used to be dead code
    # because required=True was also set.
    parser.add_argument(
        "--modality_hparams",
        nargs="+",
        type=str,
        default=["default"],
        help="Specific modality's hparams."
    )

    parser.add_argument(
        "--feat_names",
        nargs="+",
        type=str,
        required=True,
        help="Names of the feature to use.",
    )

    parser.add_argument(
        "--select_layers",
        nargs="+",
        type=str,
        default=["all"],
        help="Number of layer to use. Set None if there is no layer.",
    )

    parser.add_argument(
        "--device",
        type=int,
        default=0,
        help="GPU number",
    )

    parser.add_argument(
        "--feature_space",
        choices=["mono", "multi"],
        required=True,
        help="Number of feature space.",
    )

    parser.add_argument(
        "--reg_type",
        choices=["ridge", "gradient"],
        default="ridge",
        help="Type of the regression.",
    )

    parser.add_argument(
        "--perm_type",
        choices=["gauss", "block"],
        required=True,
        help="Type of the significance testing.",
    )

    parser.add_argument(
        "--n_iter",
        type=int,
        default=100,
        required=False,
        help="Number of random search iterations.",
    )

    parser.add_argument(
        "--reduce_dim",
        nargs="*",
        type=str,
        default=["default", None],
        required=False,
        help="Dimension reduction method and its hyperparameter.",
    )

    # main() reads args.dataset. The default matches the loaders' own default ("all").
    parser.add_argument(
        "--dataset",
        type=str,
        default="all",
        help="Dataset name forwarded to the data loaders and used to build output filenames.",
    )

    parser.add_argument(
        "--save_pred",
        action='store_true',
        help="for saving cross-validated prediction of training dataset."
    )

    parser.add_argument(
        "--voxel_step",
        type=int,
        required=False,
        help="Number of voxels to process at once."
    )

    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())
